package com.hanxiaozhang.unionfindset;

import java.util.HashMap;
import java.util.List;
import java.util.Stack;

/**
 * 〈一句话功能简述〉<br>
 * 〈并查集〉
 *
 * @author hanxinghua
 * @create 2021/9/24
 * @since 1.0.0
 */
public class UnionFindSet<V> {

    /**
     * 节点集合，key -> value v ->Node
     */
    public HashMap<V, Node<V>> nodes;

    /**
     * 节点与父亲的节点关系
     */
    public HashMap<Node<V>, Node<V>> parents;

    /**
     * 只有一个点，他是代表点，才有记录
     */
    public HashMap<Node<V>, Integer> sizeMap;


    /**
     * 初始化代表点
     *
     * @param values
     */
    public UnionFindSet(List<V> values) {
        nodes = new HashMap<>();
        parents = new HashMap<>();
        sizeMap = new HashMap<>();

        for (V value : values) {
            Node<V> node = new Node<>(value);
            nodes.put(value, node);
            parents.put(node, node);
            sizeMap.put(node, 1);
        }
    }

    /**
     * 从cur开始，一直往上找，找到不能往上的代表点，返回
     *
     * @param cur
     * @return
     */
    public Node<V> findFather(Node<V> cur) {
        // 把经过节点记录到stack
        Stack<Node<V>> path = new Stack<>();
        while (cur != parents.get(cur)) {
            path.push(cur);
            cur = parents.get(cur);
        }
        // 把经过节点都指向代表点
        while (!path.isEmpty()) {
            parents.put(path.pop(), cur);
        }
        return cur;
    }

    /**
     * 查询样本x和样本y是否属于一个集合
     *
     * @param x
     * @param y
     * @return
     */
    public boolean isSameSet(V x, V y) {
        if (!nodes.containsKey(x) || !nodes.containsKey(y)) {
            return false;
        }
        return findFather(nodes.get(x)) == findFather(nodes.get(y));
    }

    /**
     * 把x和y各自所在集合的所有样本合并成一个集合
     *
     * @param x
     * @param y
     */
    public void union(V x, V y) {
        if (!nodes.containsKey(x) || !nodes.containsKey(y)) {
            return;
        }
        Node<V> xHead = findFather(nodes.get(x));
        Node<V> yHead = findFather(nodes.get(y));
        if (xHead != yHead) {
            int xSetSize = sizeMap.get(xHead);
            int ySetSize = sizeMap.get(yHead);
            if (xSetSize >= ySetSize) {
                parents.put(yHead, xHead);
                sizeMap.put(xHead, xSetSize + ySetSize);
                sizeMap.remove(yHead);
            } else {
                parents.put(xHead, yHead);
                sizeMap.put(yHead, xSetSize + ySetSize);
                sizeMap.remove(xHead);
            }
        }
    }


    public int size() {
        return sizeMap.size();
    }

}
